# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import re
import subprocess
import time

from common.benchmark_command import BenchmarkCommand

# Regexes for retrieving memory information.
_VMHWM_REGEX = re.compile(r".*?VmHWM:.*?(\d+) kB.*")
_VMRSS_REGEX = re.compile(r".*?VmRSS:.*?(\d+) kB.*")
_RSSFILE_REGEX = re.compile(r".*?RssFile:.*?(\d+) kB.*")


def _read_memory_stats(pid: int) -> "tuple[float, float, float] | None":
    """Samples VmHWM, VmRSS and RssFile (in kB) from /proc/<pid>/status.

    Args:
      pid: Process id whose status file is read.
    Returns:
      A `(vmhwm, vmrss, rssfile)` tuple of floats, or None when the status file
      cannot be read (e.g. the process is already gone) or any of the three
      fields is missing from it.
    """
    try:
        with open(
            f"/proc/{pid}/status", encoding="utf-8", errors="replace"
        ) as status_file:
            status = status_file.read()
    except OSError:
        # The process may exit between polls; treat that as "no sample".
        return None

    vmhwm_match = _VMHWM_REGEX.search(status)
    vmrss_match = _VMRSS_REGEX.search(status)
    rssfile_match = _RSSFILE_REGEX.search(status)
    if not (vmhwm_match and vmrss_match and rssfile_match):
        return None
    return (
        float(vmhwm_match.group(1)),
        float(vmrss_match.group(1)),
        float(rssfile_match.group(1)),
    )


def run_command(
    benchmark_command: BenchmarkCommand, poll_interval_s: float = 0.5
) -> list[float]:
    """Runs `benchmark_command` and polls for memory consumption statistics.

    Args:
      benchmark_command: A `BenchmarkCommand` object containing information on
        how to run the benchmark and parse the output.
      poll_interval_s: Seconds to wait between memory samples.
    Returns:
      An array containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`].
      Memory values are in kB and are 0 if no sample was ever taken. If the
      benchmark exits with a non-zero return code, [0, 0, 0, 0] is returned.
    """
    command = benchmark_command.generate_benchmark_command()
    print("\n\nRunning command:\n" + " ".join(command))

    # Keep a record of the highest VmHWM and the VmRSS/RssFile values sampled
    # together with it.
    vmhwm = 0
    vmrss = 0
    rssfile = 0
    # The context manager makes sure the stdout pipe is closed even if polling
    # raises.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ) as benchmark_process:
        while True:
            stats = _read_memory_stats(benchmark_process.pid)
            if stats is not None and stats[0] > vmhwm:
                vmhwm, vmrss, rssfile = stats

            # Wait via communicate() rather than sleep()+poll(): it keeps
            # draining the stdout pipe, so a benchmark that writes more than
            # the OS pipe buffer cannot block on write and hang this loop.
            try:
                stdout_data, _ = benchmark_process.communicate(
                    timeout=poll_interval_s
                )
            except subprocess.TimeoutExpired:
                # Still running; retrying communicate() loses no output.
                continue
            break

    if benchmark_process.returncode != 0:
        print(
            f"Warning! Benchmark command failed with return code:"
            f" {benchmark_process.returncode}"
        )
        return [0, 0, 0, 0]

    # Decode once; tolerate stray non-UTF-8 bytes instead of crashing.
    output = stdout_data.decode(errors="replace")
    print(output)

    latency_ms = benchmark_command.parse_latency_from_output(output)
    return [latency_ms, vmhwm, vmrss, rssfile]
